/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package server

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"html/template"
	"io"
	"io/ioutil"
	"net"
	"net/http"
	"path"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/cloudwego/hertz/internal/test/mock/binder"
	"github.com/cloudwego/hertz/internal/testutils"
	"github.com/cloudwego/hertz/pkg/app"
	c "github.com/cloudwego/hertz/pkg/app/client"
	"github.com/cloudwego/hertz/pkg/app/server/binding"
	"github.com/cloudwego/hertz/pkg/app/server/registry"
	"github.com/cloudwego/hertz/pkg/common/config"
	errs "github.com/cloudwego/hertz/pkg/common/errors"
	"github.com/cloudwego/hertz/pkg/common/test/assert"
	"github.com/cloudwego/hertz/pkg/common/test/mock"
	"github.com/cloudwego/hertz/pkg/common/utils"
	"github.com/cloudwego/hertz/pkg/network"
	"github.com/cloudwego/hertz/pkg/network/standard"
	"github.com/cloudwego/hertz/pkg/protocol"
	"github.com/cloudwego/hertz/pkg/protocol/consts"
	"github.com/cloudwego/hertz/pkg/protocol/http1/req"
	"github.com/cloudwego/hertz/pkg/protocol/http1/resp"
)

// routeEngine is the minimal view of a server the tests need in order to wait
// for it to start: anything that can report whether it is running.
type routeEngine interface {
	IsRunning() bool
}

// waitEngineRunning waits for e to start serving by delegating to
// testutils.WaitEngineRunning; the waiting/timeout policy lives there
// (presumably it polls e.IsRunning — confirm against testutils).
func waitEngineRunning(e routeEngine) {
	testutils.WaitEngineRunning(e)
}

// fullURL builds an absolute "http://" URL for path p on the listener's
// address. The host and p are combined with path.Join, so an empty p yields
// the bare host:port and redundant slashes are cleaned away.
func fullURL(ln net.Listener, p string) string {
	hostAndPath := path.Join(ln.Addr().String(), p)
	return "http://" + hostAndPath
}

// TestHertz_Run checks that a server started with Spin stops accepting
// requests once Close is called, and that Close does not run the OnShutdown
// hooks.
func TestHertz_Run(t *testing.T) {
	ln := testutils.NewTestListener(t)
	defer ln.Close()

	var hookCalled uint32
	hertz := Default(WithListener(ln))
	hertz.GET("/test", func(c context.Context, ctx *app.RequestContext) {
		time.Sleep(time.Second)
		ctx.SetBodyString(string(ctx.Request.URI().PathOriginal()))
	})
	hertz.Engine.OnShutdown = append(hertz.OnShutdown, func(ctx context.Context) {
		atomic.StoreUint32(&hookCalled, 1)
	})

	assert.Assert(t, len(hertz.Handlers) == 1)

	go hertz.Spin()
	waitEngineRunning(hertz)

	hertz.Close()
	got, err := http.Get(fullURL(ln, "/test"))
	assert.NotNil(t, err)
	assert.Nil(t, got)
	assert.DeepEqual(t, uint32(0), atomic.LoadUint32(&hookCalled))
}

// TestHertz_GracefulShutdown checks graceful shutdown. Shutdown must return
// while a request is still being handled, and that in-flight request must
// still complete with its connection marked for closing. Every OnShutdown hook
// must run, and new requests must eventually fail.
func TestHertz_GracefulShutdown(t *testing.T) {
	ln := testutils.NewTestListener(t)
	defer ln.Close()
	// handling is closed once the /test handler starts; closing releases it.
	handling := make(chan struct{})
	closing := make(chan struct{})
	engine := New(WithListener(ln))
	engine.GET("/test", func(c context.Context, ctx *app.RequestContext) {
		close(handling)
		<-closing
		path := ctx.Request.URI().PathOriginal()
		ctx.SetBodyString(string(path))
	})
	engine.GET("/test2", func(c context.Context, ctx *app.RequestContext) {})

	// Each hook stores a distinct value so the test can tell that all three ran.
	testint := uint32(0)
	testint2 := uint32(0)
	testint3 := uint32(0)
	engine.Engine.OnShutdown = append(engine.OnShutdown, func(ctx context.Context) {
		atomic.StoreUint32(&testint, 1)
	})
	engine.Engine.OnShutdown = append(engine.OnShutdown, func(ctx context.Context) {
		atomic.StoreUint32(&testint2, 2)
	})
	engine.Engine.OnShutdown = append(engine.OnShutdown, func(ctx context.Context) {
		atomic.StoreUint32(&testint3, 3)
	})

	go engine.Spin()
	waitEngineRunning(engine)

	hc := http.Client{Timeout: time.Second}
	// err/resp are written by the /test goroutine and read only after <-ch.
	var err error
	var resp *http.Response
	// ch: the /test request finished. ch2: /test2 polling observed a failure.
	ch := make(chan struct{})
	ch2 := make(chan struct{})
	// Poll /test2 every 10ms until a request fails, i.e. the server no longer
	// accepts new requests, then signal ch2.
	// NOTE(review): bodies of successful polls are never closed, and if the test
	// fails before <-ch2 this goroutine stays blocked on the send.
	go func() {
		ticker := time.NewTicker(10 * time.Millisecond)
		defer ticker.Stop()
		for range ticker.C {
			t.Logf("[%v]begin listening\n", time.Now())
			_, err2 := hc.Get(fullURL(ln, "/test2"))
			if err2 != nil {
				t.Logf("[%v]listening closed: %v", time.Now(), err2)
				ch2 <- struct{}{}
				break
			}
		}
	}()
	// Issue the long-running request that will be in flight during Shutdown.
	go func() {
		t.Logf("[%v]begin request\n", time.Now())
		resp, err = http.Get(fullURL(ln, "/test"))
		t.Logf("[%v]end request\n", time.Now())
		ch <- struct{}{}
	}()

	// Start the shutdown only once the /test handler is actually running.
	<-handling

	start := time.Now()
	// Short deadline: Shutdown must return while /test is still blocked, because
	// closing is only closed after Shutdown returns.
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	t.Logf("[%v]begin shutdown\n", start)
	engine.Shutdown(ctx)
	end := time.Now()
	t.Logf("[%v]end shutdown\n", end)

	// Release the in-flight handler; its response must still be delivered.
	close(closing)
	<-ch
	assert.Nil(t, err)
	assert.NotNil(t, resp)
	// The response header directed the client to close the connection.
	assert.DeepEqual(t, true, resp.Close)
	assert.DeepEqual(t, uint32(1), atomic.LoadUint32(&testint))
	assert.DeepEqual(t, uint32(2), atomic.LoadUint32(&testint2))
	assert.DeepEqual(t, uint32(3), atomic.LoadUint32(&testint3))

	// Wait until new requests are being rejected.
	<-ch2

	cancel()
}

// TestLoadHTMLGlob renders a template loaded via LoadHTMLGlob, using the
// custom "{[{" / "}]}" delimiters, and checks the rendered body.
func TestLoadHTMLGlob(t *testing.T) {
	ln := testutils.NewTestListener(t)
	defer ln.Close()
	engine := New(WithMaxRequestBodySize(15), WithListener(ln))
	engine.Delims("{[{", "}]}")
	engine.LoadHTMLGlob("../../common/testdata/template/index.tmpl")
	engine.GET("/index", func(c context.Context, ctx *app.RequestContext) {
		ctx.HTML(consts.StatusOK, "index.tmpl", utils.H{
			"title": "Main website",
		})
	})
	go engine.Run()
	defer func() {
		engine.Close()
	}()
	waitEngineRunning(engine)

	resp, err := http.Get(fullURL(ln, "/index"))
	// Fail cleanly instead of dereferencing a nil response below.
	assert.Nil(t, err)
	defer resp.Body.Close()
	assert.DeepEqual(t, consts.StatusOK, resp.StatusCode)

	// Read the whole body: a single Read may legally return only part of it.
	body, err := ioutil.ReadAll(resp.Body)
	assert.Nil(t, err)
	const expected = `<html><h1>Main website</h1></html>`

	assert.DeepEqual(t, expected, string(body))
}

// TestLoadHTMLFiles renders a template loaded via LoadHTMLFiles that calls
// the custom "formatAsDate" function registered with SetFuncMap, using the
// "{[{" / "}]}" delimiters.
func TestLoadHTMLFiles(t *testing.T) {
	ln := testutils.NewTestListener(t)
	defer ln.Close()
	engine := New(WithMaxRequestBodySize(15), WithListener(ln))
	engine.Delims("{[{", "}]}")
	engine.SetFuncMap(template.FuncMap{
		"formatAsDate": formatAsDate,
	})
	engine.LoadHTMLFiles("../../common/testdata/template/htmltemplate.html", "../../common/testdata/template/index.tmpl")

	engine.GET("/raw", func(c context.Context, ctx *app.RequestContext) {
		ctx.HTML(consts.StatusOK, "htmltemplate.html", map[string]interface{}{
			"now": time.Date(2017, 0o7, 0o1, 0, 0, 0, 0, time.UTC),
		})
	})
	go engine.Run()
	defer func() {
		engine.Close()
	}()
	waitEngineRunning(engine)

	resp, err := http.Get(fullURL(ln, "/raw"))
	// Fail cleanly instead of dereferencing a nil response below.
	assert.Nil(t, err)
	defer resp.Body.Close()
	assert.DeepEqual(t, consts.StatusOK, resp.StatusCode)

	// Read the whole body: a single Read may legally return only part of it.
	body, err := ioutil.ReadAll(resp.Body)
	assert.Nil(t, err)
	assert.DeepEqual(t, "<h1>Date: 2017/07/01</h1>", string(body))
}

// formatAsDate renders t's calendar date as "YYYY/MM/DD". The year is printed
// without padding and the month and day are zero-padded to two digits. It is
// exposed to templates as "formatAsDate".
func formatAsDate(t time.Time) string {
	y, m, d := t.Date()
	return fmt.Sprintf("%d/%02d/%02d", y, int(m), d)
}

// copied from router: the response bodies the router writes when it rejects a
// request with 400 Bad Request. default400Body is the generic body;
// requiredHostBody is used when the Host header is missing or empty.
var (
	default400Body   = []byte("Bad Request")
	requiredHostBody = []byte("missing required Host header")
)

// TestServer_Use checks that every Use call appends exactly one global
// middleware to the engine's handler chain.
func TestServer_Use(t *testing.T) {
	router := New()
	for want := 1; want <= 2; want++ {
		router.Use(func(c context.Context, ctx *app.RequestContext) {})
		assert.DeepEqual(t, want, len(router.Handlers))
	}
}

// Test_getServerName checks that GetServerName reports "hertz" by default and
// the custom Name once one is set.
func Test_getServerName(t *testing.T) {
	cases := []struct {
		name string // value assigned to Name; empty keeps the default
		want string
	}{
		{name: "", want: "hertz"},
		{name: "test_name", want: "test_name"},
	}
	for _, tc := range cases {
		engine := New()
		if tc.name != "" {
			engine.Name = tc.name
		}
		assert.DeepEqual(t, []byte(tc.want), engine.GetServerName())
	}
}

// TestServer_Run exercises a server started with Run:
//   - a matched GET echoes its path;
//   - an unknown path yields 404;
//   - a POST answered with a 301 redirect is followed by the client to /test.
//
// The server is then shut down with an already-expired context.
func TestServer_Run(t *testing.T) {
	ln := testutils.NewTestListener(t)
	defer ln.Close()
	hertz := New(WithListener(ln))
	hertz.GET("/test", func(c context.Context, ctx *app.RequestContext) {
		path := ctx.Request.URI().PathOriginal()
		ctx.SetBodyString(string(path))
	})
	hertz.POST("/redirect", func(c context.Context, ctx *app.RequestContext) {
		ctx.Redirect(consts.StatusMovedPermanently, []byte(fullURL(ln, "/test")))
	})
	go hertz.Run()
	waitEngineRunning(hertz)

	// readBody drains and closes r's body so no connection is leaked. It reads
	// everything because a single Read may return only part of the body.
	readBody := func(r *http.Response) string {
		defer r.Body.Close()
		b, err := ioutil.ReadAll(r.Body)
		assert.Nil(t, err)
		return string(b)
	}

	resp, err := http.Get(fullURL(ln, "/test"))
	assert.Nil(t, err)
	assert.DeepEqual(t, consts.StatusOK, resp.StatusCode)
	assert.DeepEqual(t, "/test", readBody(resp))

	resp, err = http.Get(fullURL(ln, "/foo"))
	assert.Nil(t, err)
	assert.DeepEqual(t, consts.StatusNotFound, resp.StatusCode)
	readBody(resp)

	resp, err = http.Post(fullURL(ln, "/redirect"), "", nil)
	assert.Nil(t, err)
	assert.DeepEqual(t, consts.StatusOK, resp.StatusCode)
	assert.DeepEqual(t, "/test", readBody(resp))

	ctx, cancel := context.WithTimeout(context.Background(), 0)
	defer cancel()
	_ = hertz.Shutdown(ctx)
}

// TestNotAbsolutePath feeds raw requests whose request-URI is not absolute
// ("?a=b" and "a?a=b") directly into ServeHTTP. It expects the POST handlers
// to answer 200 and echo the request body back.
func TestNotAbsolutePath(t *testing.T) {
	ln := testutils.NewTestListener(t)
	defer ln.Close()
	engine := New(WithListener(ln))
	echo := func(c context.Context, ctx *app.RequestContext) {
		ctx.Write(ctx.Request.Body())
	}
	engine.POST("/", echo)
	engine.POST("/a", echo)
	go engine.Run()
	defer func() {
		engine.Close()
	}()
	waitEngineRunning(engine)

	rawRequests := []string{
		"POST ?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343",
		"POST a?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343",
	}
	for _, raw := range rawRequests {
		ctx := app.NewContext(0)
		if err := req.Read(&ctx.Request, mock.NewZeroCopyReader(raw)); err != nil {
			t.Fatalf("unexpected error: %s", err)
		}
		engine.ServeHTTP(context.Background(), ctx)
		assert.DeepEqual(t, consts.StatusOK, ctx.Response.StatusCode())
		assert.DeepEqual(t, ctx.Request.Body(), ctx.Response.Body())
	}
}

// TestNotAbsolutePathWithRawPath enables UseRawPath and feeds raw requests
// whose request-URI is not absolute ("?a=b" and "a?a=b") into ServeHTTP. It
// expects 400 with default400Body, and checks that the global middleware
// still ran and set its response header.
func TestNotAbsolutePathWithRawPath(t *testing.T) {
	ln := testutils.NewTestListener(t)
	defer ln.Close()
	engine := New(WithListener(ln), WithUseRawPath(true))
	const (
		mwKey   = "middleware_key"
		mwValue = "middleware_value"
	)
	engine.Use(func(c context.Context, ctx *app.RequestContext) {
		ctx.Response.Header.Set(mwKey, mwValue)
	})
	noop := func(c context.Context, ctx *app.RequestContext) {}
	engine.POST("/", noop)
	engine.POST("/a", noop)
	go engine.Run()
	defer func() {
		engine.Close()
	}()
	waitEngineRunning(engine)

	rawRequests := []string{
		"POST ?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343",
		"POST a?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343",
	}
	for _, raw := range rawRequests {
		ctx := app.NewContext(0)
		if err := req.Read(&ctx.Request, mock.NewZeroCopyReader(raw)); err != nil {
			t.Fatalf("unexpected error: %s", err)
		}
		engine.ServeHTTP(context.Background(), ctx)
		assert.DeepEqual(t, consts.StatusBadRequest, ctx.Response.StatusCode())
		assert.DeepEqual(t, default400Body, ctx.Response.Body())
		assert.DeepEqual(t, mwValue, ctx.Response.Header.Get(mwKey))
	}
}

// TestNotValidHost feeds ServeHTTP two requests, one with an empty Host header
// and one with no Host header at all. It expects 400 with requiredHostBody and
// checks that the global middleware still ran. No server is started; requests
// go straight to ServeHTTP.
func TestNotValidHost(t *testing.T) {
	ln := testutils.NewTestListener(t)
	defer ln.Close()
	engine := New(WithListener(ln))
	const (
		mwKey   = "middleware_key"
		mwValue = "middleware_value"
	)
	engine.Use(func(c context.Context, ctx *app.RequestContext) {
		ctx.Response.Header.Set(mwKey, mwValue)
	})
	noop := func(c context.Context, ctx *app.RequestContext) {}
	engine.POST("/", noop)
	engine.POST("/a", noop)

	rawRequests := []string{
		"POST ?a=b HTTP/1.1\r\nHost: \r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343",
		"POST a?a=b HTTP/1.1\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343",
	}
	for _, raw := range rawRequests {
		ctx := app.NewContext(0)
		if err := req.Read(&ctx.Request, mock.NewZeroCopyReader(raw)); err != nil {
			t.Fatalf("unexpected error: %s", err)
		}
		engine.ServeHTTP(context.Background(), ctx)
		assert.DeepEqual(t, consts.StatusBadRequest, ctx.Response.StatusCode())
		assert.DeepEqual(t, requiredHostBody, ctx.Response.Body())
		assert.DeepEqual(t, mwValue, ctx.Response.Header.Get(mwKey))
	}
}

// TestWithBasePath checks that routes are served under the prefix given by
// WithBasePath: POST /test is reachable as /hertz/test.
func TestWithBasePath(t *testing.T) {
	ln := testutils.NewTestListener(t)
	defer ln.Close()
	engine := New(WithBasePath("/hertz"), WithListener(ln))
	engine.POST("/test", func(c context.Context, ctx *app.RequestContext) {
	})
	go engine.Run()
	defer func() {
		engine.Close()
	}()
	waitEngineRunning(engine)

	// A zero http.Request is used only as a convenient url.Values builder.
	var r http.Request
	r.ParseForm()
	r.Form.Add("xxxxxx", "xxx")
	body := strings.NewReader(r.Form.Encode())
	resp, err := http.Post(fullURL(ln, "/hertz/test"), "application/x-www-form-urlencoded", body)
	assert.Nil(t, err)
	// Close the body so the client connection is not leaked.
	defer resp.Body.Close()
	assert.DeepEqual(t, consts.StatusOK, resp.StatusCode)
}

// TestNotEnoughBodySize checks that a body larger than WithMaxRequestBodySize
// is rejected with 413 and the "Request Entity Too Large" body. The form body
// "xxxxxx=xxx" is 10 bytes against a 5-byte limit.
func TestNotEnoughBodySize(t *testing.T) {
	ln := testutils.NewTestListener(t)
	defer ln.Close()
	engine := New(WithMaxRequestBodySize(5), WithListener(ln))
	engine.POST("/test", func(c context.Context, ctx *app.RequestContext) {
	})
	go engine.Run()
	defer func() {
		engine.Close()
	}()
	waitEngineRunning(engine)

	// A zero http.Request is used only as a convenient url.Values builder.
	var r http.Request
	r.ParseForm()
	r.Form.Add("xxxxxx", "xxx")
	body := strings.NewReader(r.Form.Encode())
	resp, err := http.Post(fullURL(ln, "/test"), "application/x-www-form-urlencoded", body)
	assert.Nil(t, err)
	defer resp.Body.Close()
	assert.DeepEqual(t, http.StatusRequestEntityTooLarge, resp.StatusCode)
	bodyBytes, err := ioutil.ReadAll(resp.Body)
	assert.Nil(t, err)
	assert.DeepEqual(t, "Request Entity Too Large", string(bodyBytes))
}

// TestEnoughBodySize checks that a body within WithMaxRequestBodySize is
// accepted. The form body "xxxxxx=xxx" is 10 bytes against a 15-byte limit.
func TestEnoughBodySize(t *testing.T) {
	ln := testutils.NewTestListener(t)
	defer ln.Close()
	engine := New(WithMaxRequestBodySize(15), WithListener(ln))
	engine.POST("/test", func(c context.Context, ctx *app.RequestContext) {
	})
	go engine.Run()
	defer func() {
		engine.Close()
	}()
	waitEngineRunning(engine)

	// A zero http.Request is used only as a convenient url.Values builder.
	var r http.Request
	r.ParseForm()
	r.Form.Add("xxxxxx", "xxx")
	body := strings.NewReader(r.Form.Encode())
	resp, err := http.Post(fullURL(ln, "/test"), "application/x-www-form-urlencoded", body)
	// Fail cleanly instead of dereferencing a nil response.
	assert.Nil(t, err)
	defer resp.Body.Close()
	assert.DeepEqual(t, consts.StatusOK, resp.StatusCode)
}

// TestRequestCtxHijack checks that a handler can hijack the connection:
//   - the handler's normal response ("hijack it!") is written to the conn;
//   - the hijack callback then echoes every remaining byte back until EOF.
func TestRequestCtxHijack(t *testing.T) {
	hijackStartCh := make(chan struct{})
	hijackStopCh := make(chan struct{})
	engine := New()
	engine.Init()

	engine.GET("/foo", func(c context.Context, ctx *app.RequestContext) {
		if ctx.Hijacked() {
			t.Error("connection mustn't be hijacked")
		}
		ctx.Hijack(func(conn network.Conn) {
			<-hijackStartCh

			b := make([]byte, 1)
			// ping-pong echo via hijacked conn
			for {
				n, err := conn.Read(b)
				if n != 1 {
					if errors.Is(err, io.EOF) {
						close(hijackStopCh)
						return
					}
					if err != nil {
						t.Errorf("unexpected error: %s", err)
					}
					t.Errorf("unexpected number of bytes read: %d. Expecting 1", n)
					// Stop here: looping again on a broken conn would spin forever,
					// reporting the same failure on every iteration.
					return
				}
				if _, err = conn.Write(b); err != nil {
					t.Errorf("unexpected error when writing data: %s", err)
				}
			}
		})
		if !ctx.Hijacked() {
			t.Error("connection must be hijacked")
		}
		ctx.Data(consts.StatusOK, "foo/bar", []byte("hijack it!"))
	})

	hijackedString := "foobar baz hijacked!!!"

	mockConn := mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n" + hijackedString)

	ch := make(chan error)
	go func() {
		ch <- engine.Serve(context.Background(), mockConn)
	}()

	// NOTE(review): fixed delay before releasing the hijack callback; presumably
	// gives Serve time to write the normal response first.
	time.Sleep(100 * time.Millisecond)

	close(hijackStartCh)

	if err := <-ch; err != nil {
		if !errors.Is(err, errs.ErrHijacked) {
			t.Fatalf("Unexpected error from serveConn: %s", err)
		}
	}
	verifyResponse(t, mockConn.WriterRecorder(), consts.StatusOK, "foo/bar", "hijack it!")

	select {
	case <-hijackStopCh:
	case <-time.After(100 * time.Millisecond):
		t.Fatal("timeout")
	}

	// Whatever follows the first response must be the echoed hijacked payload.
	zw := mockConn.WriterRecorder()
	data, err := zw.ReadBinary(zw.Len())
	if err != nil {
		t.Fatalf("Unexpected error when reading remaining data: %s", err)
	}
	if string(data) != hijackedString {
		t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, hijackedString)
	}
}

// verifyResponse parses one HTTP response from zr and fails the test unless
// all of the following hold:
//   - the body equals expectedBody;
//   - the status code and content type match;
//   - Content-Length equals the body length;
//   - there is no Content-Encoding.
func verifyResponse(t *testing.T, zr network.Reader, expectedStatusCode int, expectedContentType, expectedBody string) {
	// Report failures at the caller's line instead of inside this helper.
	t.Helper()
	var r protocol.Response
	if err := resp.Read(&r, zr); err != nil {
		t.Fatalf("Unexpected error when parsing response: %s", err)
	}

	if !bytes.Equal(r.Body(), []byte(expectedBody)) {
		t.Fatalf("Unexpected body %q. Expected %q", r.Body(), []byte(expectedBody))
	}
	verifyResponseHeader(t, &r.Header, expectedStatusCode, len(r.Body()), expectedContentType, "")
}

// verifyResponseHeader fails the test on the first mismatch among h's status
// code, Content-Length, Content-Type and Content-Encoding.
func verifyResponseHeader(t *testing.T, h *protocol.ResponseHeader, expectedStatusCode, expectedContentLength int, expectedContentType, expectedContentEncoding string) {
	// Report failures at the caller's line instead of inside this helper.
	t.Helper()
	if h.StatusCode() != expectedStatusCode {
		t.Fatalf("Unexpected status code %d. Expected %d", h.StatusCode(), expectedStatusCode)
	}
	if h.ContentLength() != expectedContentLength {
		t.Fatalf("Unexpected content length %d. Expected %d", h.ContentLength(), expectedContentLength)
	}
	if string(h.ContentType()) != expectedContentType {
		t.Fatalf("Unexpected content type %q. Expected %q", h.ContentType(), expectedContentType)
	}
	if string(h.ContentEncoding()) != expectedContentEncoding {
		t.Fatalf("Unexpected content encoding %q. Expected %q", h.ContentEncoding(), expectedContentEncoding)
	}
}

// TestParamInconsist hammers the parameterised route "/:label" from many
// goroutines with two different values. It checks that every handler
// invocation sees the same string that was first stored for its label.
// Presumably this guards against param strings aliasing reused request
// buffers; confirm.
func TestParamInconsist(t *testing.T) {
	ln := testutils.NewTestListener(t)
	defer ln.Close()
	mapS := sync.Map{}
	h := New(WithListener(ln))
	h.GET("/:label", func(c context.Context, ctx *app.RequestContext) {
		label := ctx.Param("label")
		x, _ := mapS.LoadOrStore(label, label)
		labelString := x.(string)
		if label != labelString {
			t.Errorf("unexpected label: %s, expected return label: %s", label, labelString)
		}
	})
	go h.Run()
	defer h.Close()
	waitEngineRunning(h)

	client, err := c.NewClient()
	assert.Nil(t, err)

	var wg sync.WaitGroup
	// hammer issues 500 sequential GETs to p. Client errors are ignored; only
	// the server-side consistency check matters here.
	hammer := func(p string) {
		defer wg.Done()
		for i := 0; i < 500; i++ {
			client.Get(context.Background(), nil, fullURL(ln, p))
		}
	}

	for i := 0; i < 30; i++ {
		// Add before starting the goroutines so a Done can never run ahead of
		// the matching Add.
		wg.Add(2)
		go hammer("/test1")
		go hammer("/test2")
	}
	wg.Wait()
}

// TestDuplicateReleaseBodyStream sends 10 concurrent 102388-byte POSTs to a
// server running with WithStreamBody(true). The handler echoes the request
// body stream back as the response body stream, with length -1 (presumably
// "unknown"). The test checks that every echoed body matches what was sent.
// Presumably this is a regression test for the body stream being released
// twice; confirm.
func TestDuplicateReleaseBodyStream(t *testing.T) {
	ln := testutils.NewTestListener(t)
	defer ln.Close()
	h := New(WithStreamBody(true), WithListener(ln))
	h.POST("/test", func(ctx context.Context, c *app.RequestContext) {
		stream := c.RequestBodyStream()
		c.Response.SetBodyStream(stream, -1)
	})
	go h.Spin()
	waitEngineRunning(h)

	client, _ := c.NewClient(c.WithMaxConnsPerHost(1000000), c.WithDialTimeout(time.Minute))
	// Build a deterministic body that moves to the next letter roughly every
	// 1969 bytes, so mixed-up chunks would show up as a body mismatch.
	bodyBytes := make([]byte, 102388)
	index := 0
	letterBytes := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
	for i := 0; i < 102388; i++ {
		bodyBytes[i] = letterBytes[index]
		if i%1969 == 0 && i != 0 {
			index = index + 1
		}
	}
	body := string(bodyBytes)

	wg := sync.WaitGroup{}
	// testFunc performs one echo round-trip on its own goroutine, so failures
	// are reported with t.Errorf.
	// NOTE(review): resp is never handed back via protocol.ReleaseResponse.
	testFunc := func() {
		defer wg.Done()
		r := protocol.NewRequest("POST", fullURL(ln, "/test"), nil)
		r.SetBodyString(body)
		resp := protocol.AcquireResponse()
		err := client.Do(context.Background(), r, resp)
		if err != nil {
			t.Errorf("unexpected error: %s", err.Error())
		}
		if body != string(resp.Body()) {
			t.Errorf("unequal body")
		}
	}

	for i := 0; i < 10; i++ {
		wg.Add(1)
		go testFunc()
	}
	wg.Wait()
}

// TestServiceRegisterFailed checks that Register is attempted exactly once
// when it returns an error. Spin is called on the test goroutine, so the test
// relies on Spin returning after the registration error (confirm against
// Hertz.Spin).
// NOTE(review): drCount is incremented but never asserted.
func TestServiceRegisterFailed(t *testing.T) {
	t.Parallel() // slow test, make it parallel

	ln := testutils.NewTestListener(t)
	defer ln.Close()
	mockRegErr := errors.New("mock register error")
	var rCount int32
	var drCount int32
	mockRegistry := MockRegistry{
		RegisterFunc: func(info *registry.Info) error {
			atomic.AddInt32(&rCount, 1)
			return mockRegErr
		},
		DeregisterFunc: func(info *registry.Info) error {
			atomic.AddInt32(&drCount, 1)
			return nil
		},
	}
	var opts []config.Option
	opts = append(opts, WithRegistry(mockRegistry, nil))
	opts = append(opts, WithListener(ln))
	srv := New(opts...)
	srv.Spin()
	assert.Assert(t, atomic.LoadInt32(&rCount) == 1)
}

// TestServiceDeregisterFailed checks that Register succeeds once and that,
// during Shutdown, a Deregister that returns an error is still invoked exactly
// once. Shutdown must also return despite that error.
func TestServiceDeregisterFailed(t *testing.T) {
	t.Parallel() // slow test, make it parallel

	ln := testutils.NewTestListener(t)
	defer ln.Close()
	mockDeregErr := errors.New("mock deregister error")

	// wg.Wait below blocks until one Register and one Deregister call have run.
	var wg sync.WaitGroup
	wg.Add(2) // RegisterFunc && DeregisterFunc
	var rCount int32
	var drCount int32
	mockRegistry := MockRegistry{
		RegisterFunc: func(info *registry.Info) error {
			defer wg.Done()
			atomic.AddInt32(&rCount, 1)
			return nil
		},
		DeregisterFunc: func(info *registry.Info) error {
			defer wg.Done()
			atomic.AddInt32(&drCount, 1)
			return mockDeregErr
		},
	}

	var opts []config.Option
	opts = append(opts, WithRegistry(mockRegistry, nil))
	opts = append(opts, WithListener(ln))
	srv := New(opts...)
	go srv.Spin()
	waitEngineRunning(srv)

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()
	_ = srv.Shutdown(ctx)

	wg.Wait()
	assert.Assert(t, atomic.LoadInt32(&rCount) == 1)
	assert.Assert(t, atomic.LoadInt32(&drCount) == 1)
}

// TestServiceRegistryInfo checks that the registry.Info given to WithRegistry
// (weight, service name, tags) is what Register and Deregister receive, and
// that each is called once. Shutdown uses an already-expired context.
// NOTE(review): checkInfo uses assert.Assert, presumably t.Fatal-based, from
// whichever goroutine invokes the registry callbacks.
func TestServiceRegistryInfo(t *testing.T) {
	t.Parallel() // slow test, make it parallel

	ln := testutils.NewTestListener(t)
	defer ln.Close()
	registryInfo := &registry.Info{
		Weight:      100,
		Tags:        map[string]string{"aa": "bb"},
		ServiceName: "hertz.api.test",
	}
	// checkInfo compares a received info against registryInfo field by field.
	checkInfo := func(info *registry.Info) {
		assert.Assert(t, info.Weight == registryInfo.Weight)
		assert.Assert(t, info.ServiceName == "hertz.api.test")
		assert.Assert(t, len(info.Tags) == len(registryInfo.Tags), info.Tags)
		assert.Assert(t, info.Tags["aa"] == registryInfo.Tags["aa"], info.Tags)
	}

	// wg.Wait below blocks until one Register and one Deregister call have run.
	var wg sync.WaitGroup
	wg.Add(2) // RegisterFunc && DeregisterFunc
	var rCount int32
	var drCount int32
	mockRegistry := MockRegistry{
		RegisterFunc: func(info *registry.Info) error {
			defer wg.Done()
			checkInfo(info)
			atomic.AddInt32(&rCount, 1)
			return nil
		},
		DeregisterFunc: func(info *registry.Info) error {
			defer wg.Done()
			checkInfo(info)
			atomic.AddInt32(&drCount, 1)
			return nil
		},
	}
	var opts []config.Option
	opts = append(opts, WithRegistry(mockRegistry, registryInfo))
	opts = append(opts, WithListener(ln))
	srv := New(opts...)
	go srv.Spin()
	waitEngineRunning(srv)

	ctx, cancel := context.WithTimeout(context.Background(), 0)
	defer cancel()
	_ = srv.Shutdown(ctx)
	wg.Wait()
	assert.Assert(t, atomic.LoadInt32(&rCount) == 1)
	assert.Assert(t, atomic.LoadInt32(&drCount) == 1)
}

// TestServiceRegistryNoInitInfo checks that when WithRegistry is given a nil
// *registry.Info, Register and Deregister are each called once and both
// receive nil. Shutdown uses an already-expired context.
func TestServiceRegistryNoInitInfo(t *testing.T) {
	t.Parallel() // slow test, make it parallel

	ln := testutils.NewTestListener(t)
	defer ln.Close()
	checkInfo := func(info *registry.Info) {
		assert.Assert(t, info == nil)
	}

	// wg.Wait below blocks until one Register and one Deregister call have run.
	var wg sync.WaitGroup
	wg.Add(2) // RegisterFunc && DeregisterFunc
	var rCount int32
	var drCount int32
	mockRegistry := MockRegistry{
		RegisterFunc: func(info *registry.Info) error {
			defer wg.Done()
			checkInfo(info)
			atomic.AddInt32(&rCount, 1)
			return nil
		},
		DeregisterFunc: func(info *registry.Info) error {
			defer wg.Done()
			checkInfo(info)
			atomic.AddInt32(&drCount, 1)
			return nil
		},
	}
	var opts []config.Option
	opts = append(opts, WithRegistry(mockRegistry, nil))
	opts = append(opts, WithListener(ln))
	srv := New(opts...)
	go srv.Spin()
	waitEngineRunning(srv)

	ctx, cancel := context.WithTimeout(context.Background(), 0)
	defer cancel()
	_ = srv.Shutdown(ctx)
	wg.Wait()
	assert.Assert(t, atomic.LoadInt32(&rCount) == 1)
	assert.Assert(t, atomic.LoadInt32(&drCount) == 1)
}

// testTracer is a tracer stub. Start stores an int counter under the
// "testKey" context key; Finish does nothing.
type testTracer struct{}

// Start returns ctx extended with "testKey" set to 0 when ctx carries no such
// value, or to the existing value plus one otherwise (a non-int value panics).
func (tr testTracer) Start(ctx context.Context, c *app.RequestContext) context.Context {
	next := 0
	if prev := ctx.Value("testKey"); prev != nil {
		next = prev.(int) + 1
	}
	//nolint:staticcheck // SA1029 no built-in type string as key
	return context.WithValue(ctx, "testKey", next)
}

// Finish is a no-op.
func (tr testTracer) Finish(ctx context.Context, c *app.RequestContext) {}

// TestReuseCtx checks that the tracer-derived context is not reused across
// requests. testTracer.Start increments "testKey" when the incoming context
// already carries it, so every handler invocation must see the initial 0.
func TestReuseCtx(t *testing.T) {
	ln := testutils.NewTestListener(t)
	defer ln.Close()
	h := New(WithTracer(testTracer{}), WithListener(ln))
	h.GET("/ping", func(ctx context.Context, c *app.RequestContext) {
		// This runs on a server goroutine: report with t.Errorf, because FailNow
		// is only valid on the test goroutine. The checked type assertion also
		// avoids a panic if the key is missing.
		if v, ok := ctx.Value("testKey").(int); !ok || v != 0 {
			t.Errorf("unexpected testKey value %v, want 0", ctx.Value("testKey"))
		}
	})

	go h.Spin()
	waitEngineRunning(h)

	for i := 0; i < 1000; i++ {
		_, _, err := c.Get(context.Background(), nil, fullURL(ln, "/ping"))
		assert.Nil(t, err)
	}
}

// CloseWithoutResetBuffer is implemented by connections that can be closed
// without resetting their read buffer. TestOnprepare prefers it over Close
// after peeking at the incoming request bytes.
type CloseWithoutResetBuffer interface {
	CloseNoResetBuffer() error
}

// TestOnprepare covers the connection hooks:
//  1. WithOnConnect: the hook peeks the first request bytes ("GET") and then
//     closes the connection (via CloseNoResetBuffer when supported). The
//     client then sees the "closed before first response byte" error.
//  2. WithOnAccept: closing the raw net.Conn makes the client request fail.
//  3. WithOnAccept with the standard transport: the accepted conn's local
//     address equals the listener's address.
//
// NOTE(review): the assertions inside the hooks run on server goroutines, and
// h1/h2/h3 are never shut down.
func TestOnprepare(t *testing.T) {
	ln1 := testutils.NewTestListener(t)
	defer ln1.Close()
	h1 := New(
		WithListener(ln1),
		WithOnConnect(func(ctx context.Context, conn network.Conn) context.Context {
			b, err := conn.Peek(3)
			assert.Nil(t, err)
			assert.DeepEqual(t, string(b), "GET")
			if c, ok := conn.(CloseWithoutResetBuffer); ok {
				c.CloseNoResetBuffer()
			} else {
				conn.Close()
			}
			return ctx
		}))
	h1.GET("/ping", func(ctx context.Context, c *app.RequestContext) {
		c.JSON(consts.StatusOK, utils.H{"ping": "pong"})
	})

	go h1.Spin()
	waitEngineRunning(h1)

	_, _, err := c.Get(context.Background(), nil, fullURL(ln1, "/ping"))
	assert.DeepEqual(t, "the server closed connection before returning the first response byte. Make sure the server returns 'Connection: close' response header before closing the connection", err.Error())

	ln2 := testutils.NewTestListener(t)
	defer ln2.Close()
	h2 := New(
		WithOnAccept(func(conn net.Conn) context.Context {
			conn.Close()
			return context.Background()
		}),
		WithListener(ln2))
	h2.GET("/ping", func(ctx context.Context, c *app.RequestContext) {
		c.JSON(consts.StatusOK, utils.H{"ping": "pong"})
	})
	go h2.Spin()
	waitEngineRunning(h2)

	_, _, err = c.Get(context.Background(), nil, fullURL(ln2, "/ping"))
	if err == nil {
		t.Fatalf("err should not be nil")
	}

	ln3 := testutils.NewTestListener(t)
	defer ln3.Close()
	var h3 *Hertz
	h3 = New(
		WithOnAccept(func(conn net.Conn) context.Context {
			assert.DeepEqual(t, conn.LocalAddr().String(), ln3.Addr().String())
			return context.Background()
		}),
		WithListener(ln3),
		WithTransport(standard.NewTransporter))
	h3.GET("/ping", func(ctx context.Context, c *app.RequestContext) {
		c.JSON(consts.StatusOK, utils.H{"ping": "pong"})
	})
	go h3.Spin()
	waitEngineRunning(h3)

	// The response is ignored; the check happens inside the OnAccept hook.
	c.Get(context.Background(), nil, fullURL(ln3, "/ping"))
}

// lockBuffer is a bytes.Buffer guarded by an embedded mutex. Writes and reads
// may come from different goroutines.
type lockBuffer struct {
	sync.Mutex
	b bytes.Buffer
}

// Write appends p to the buffer while holding the lock.
func (lb *lockBuffer) Write(p []byte) (int, error) {
	lb.Lock()
	defer lb.Unlock()
	n, err := lb.b.Write(p)
	return n, err
}

// String returns a snapshot of the buffered contents while holding the lock.
func (lb *lockBuffer) String() string {
	lb.Lock()
	defer lb.Unlock()
	s := lb.b.String()
	return s
}

// TestHertzDisableHeaderNamesNormalizing checks that with normalizing disabled
// on both client and server, a mixed-case header name reaches the handler
// verbatim and a header the handler sets comes back verbatim.
func TestHertzDisableHeaderNamesNormalizing(t *testing.T) {
	ln := testutils.NewTestListener(t)
	defer ln.Close()
	h := New(
		WithListener(ln),
		WithDisableHeaderNamesNormalizing(true),
	)
	headerName := "CASE-senSITive-HEAder-NAME"
	headerValue := "foobar-baz"
	h.GET("/test", func(c context.Context, ctx *app.RequestContext) {
		// Local per request, so concurrent or repeated calls don't share state.
		succeed := false
		ctx.VisitAllHeaders(func(key, value []byte) {
			if string(key) == headerName && string(value) == headerValue {
				succeed = true
			}
		})
		if !succeed {
			// Use t.Errorf, not t.Fatalf: handlers run on server goroutines and
			// FailNow must only be called from the test goroutine.
			t.Errorf("DisableHeaderNamesNormalizing failed")
			return
		}
		ctx.Header(headerName, headerValue)
	})

	go h.Spin()
	waitEngineRunning(h)

	cli, err := c.NewClient(c.WithDisableHeaderNamesNormalizing(true))
	assert.Nil(t, err)

	r := protocol.NewRequest("GET", fullURL(ln, "/test"), nil)
	r.Header.DisableNormalizing()
	r.Header.Set(headerName, headerValue)
	res := protocol.AcquireResponse()
	err = cli.Do(context.Background(), r, res)
	assert.Nil(t, err)
	assert.DeepEqual(t, headerValue, res.Header.Get(headerName))
}

// TestBindConfig checks that WithBindConfig is honoured. With LooseZeroMode,
// an empty "a=" query binds without error; without it, the same request fails
// to bind.
func TestBindConfig(t *testing.T) {
	type Req struct {
		A int `query:"a"`
	}
	bindConfig := binding.NewBindConfig()
	bindConfig.LooseZeroMode = true
	ln := testutils.NewTestListener(t)
	defer ln.Close()
	h := New(
		WithListener(ln),
		WithBindConfig(bindConfig))
	// Handlers run on server goroutines, so they report with t.Error/t.Errorf
	// rather than t.Fatal (FailNow is only valid on the test goroutine).
	h.GET("/bind", func(c context.Context, ctx *app.RequestContext) {
		var req Req
		if err := ctx.BindAndValidate(&req); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
	})

	go h.Spin()
	waitEngineRunning(h)

	hc := http.Client{Timeout: time.Second}
	resp, err := hc.Get(fullURL(ln, "/bind?a="))
	assert.Nil(t, err)
	resp.Body.Close()

	bindConfig = binding.NewBindConfig()
	bindConfig.LooseZeroMode = false
	ln2 := testutils.NewTestListener(t)
	defer ln2.Close()
	h2 := New(
		WithListener(ln2),
		WithBindConfig(bindConfig))
	h2.GET("/bind", func(c context.Context, ctx *app.RequestContext) {
		var req Req
		if err := ctx.BindAndValidate(&req); err == nil {
			t.Error("expect an error")
		}
	})

	go h2.Spin()
	waitEngineRunning(h2)

	resp, err = hc.Get(fullURL(ln2, "/bind?a="))
	assert.Nil(t, err)
	resp.Body.Close()
	time.Sleep(100 * time.Millisecond)
}

// TestCustomBinder checks that a binder installed with WithCustomBinder is
// used by BindAndValidate. The mock binder fails validation with "test binder".
func TestCustomBinder(t *testing.T) {
	type Req struct {
		A int `query:"a"`
	}
	ln := testutils.NewTestListener(t)
	defer ln.Close()
	h := New(
		WithListener(ln),
		WithCustomBinder(binder.NewBinderWithValidateError(errors.New("test binder"))))
	// The handler runs on a server goroutine, so it reports with t.Error/t.Errorf
	// rather than FailNow-based helpers.
	h.GET("/bind", func(c context.Context, ctx *app.RequestContext) {
		var req Req
		err := ctx.BindAndValidate(&req)
		if err == nil {
			t.Error("expect an error")
			return
		}
		if got := err.Error(); got != "test binder" {
			t.Errorf("unexpected error: got %q, want %q", got, "test binder")
		}
	})

	go h.Spin()
	waitEngineRunning(h)

	hc := http.Client{Timeout: time.Second}
	resp, err := hc.Get(fullURL(ln, "/bind?a="))
	assert.Nil(t, err)
	resp.Body.Close()
	time.Sleep(100 * time.Millisecond)
}

// TestValidateConfigRegValidateFunc checks that a function registered through
// ValidateConfig.MustRegValidateFunc can be called from a `vd` expression.
// validateConfig is never passed to New, so the registration presumably takes
// effect process-wide; confirm against binding.ValidateConfig.
func TestValidateConfigRegValidateFunc(t *testing.T) {
	type Req struct {
		A int `query:"a" vd:"f($)"`
	}
	validateConfig := &binding.ValidateConfig{}
	validateConfig.MustRegValidateFunc("f", func(args ...interface{}) error {
		return fmt.Errorf("test validator")
	})
	ln := testutils.NewTestListener(t)
	defer ln.Close()
	h := New(WithListener(ln))
	// The handler runs on a server goroutine, so it reports with t.Error/t.Errorf
	// rather than FailNow-based helpers.
	h.GET("/bind", func(c context.Context, ctx *app.RequestContext) {
		var req Req
		err := ctx.BindAndValidate(&req)
		if err == nil {
			t.Error("expect an error")
			return
		}
		if got := err.Error(); got != "test validator" {
			t.Errorf("unexpected error: got %q, want %q", got, "test validator")
		}
	})

	go h.Spin()
	waitEngineRunning(h)

	hc := http.Client{Timeout: time.Second}
	resp, err := hc.Get(fullURL(ln, "/bind?a=2"))
	assert.Nil(t, err)
	resp.Body.Close()
	time.Sleep(100 * time.Millisecond)
}

// TestCustomValidator checks that a validator installed with
// WithCustomValidatorFunc replaces `vd` tag validation. The mock validator
// always fails with "test mock validator".
func TestCustomValidator(t *testing.T) {
	type Req struct {
		A int `query:"a" vd:"f($)"`
	}
	ln := testutils.NewTestListener(t)
	defer ln.Close()
	h := New(
		WithListener(ln),
		WithCustomValidatorFunc(func(_ *protocol.Request, _ interface{}) error {
			return errors.New("test mock validator")
		}))
	// The handler runs on a server goroutine, so it reports with t.Error/t.Errorf
	// rather than FailNow-based helpers.
	h.GET("/bind", func(c context.Context, ctx *app.RequestContext) {
		var req Req
		err := ctx.BindAndValidate(&req)
		if err == nil {
			t.Error("expect an error")
			return
		}
		if got := err.Error(); got != "test mock validator" {
			t.Errorf("unexpected error: got %q, want %q", got, "test mock validator")
		}
	})

	go h.Spin()
	// Wait for readiness instead of relying on a fixed sleep.
	waitEngineRunning(h)

	hc := http.Client{Timeout: time.Second}
	resp, err := hc.Get(fullURL(ln, "/bind?a=2"))
	assert.Nil(t, err)
	resp.Body.Close()
	time.Sleep(100 * time.Millisecond)
}

// ValidateError is the custom validation error built by the error factory in
// TestValidateConfigSetSetErrorFactory.
type ValidateError struct {
	ErrType, FailField, Msg string
}

// Error implements error interface. It renders
// "<ErrType>: expr_path=<FailField>, cause=<Msg>" and uses "invalid" as the
// cause when Msg is empty.
func (e *ValidateError) Error() string {
	cause := "invalid"
	if e.Msg != "" {
		cause = e.Msg
	}
	return e.ErrType + ": expr_path=" + e.FailField + ", cause=" + cause
}

// TestValidateConfigSetSetErrorFactory checks that the error factory set with
// ValidateConfig.SetValidatorErrorFactory builds the error returned when a
// `vd` expression fails (b=1 violates $>100).
func TestValidateConfigSetSetErrorFactory(t *testing.T) {
	type TestValidate struct {
		B int `query:"b" vd:"$>100"`
	}
	CustomValidateErrFunc := func(failField, msg string) error {
		err := ValidateError{
			ErrType:   "validateErr",
			FailField: "[validateFailField]: " + failField,
			Msg:       "[validateErrMsg]: " + msg,
		}

		return &err
	}
	validateConfig := binding.NewValidateConfig()
	validateConfig.SetValidatorErrorFactory(CustomValidateErrFunc)
	ln := testutils.NewTestListener(t)
	defer ln.Close()
	h := New(
		WithListener(ln),
		WithValidateConfig(validateConfig))
	// The handler runs on a server goroutine, so it reports with t.Error/t.Errorf
	// rather than FailNow-based helpers.
	h.GET("/bind", func(c context.Context, ctx *app.RequestContext) {
		var req TestValidate
		err := ctx.BindAndValidate(&req)
		if err == nil {
			t.Error("expect an error")
			return
		}
		const want = "validateErr: expr_path=[validateFailField]: B, cause=[validateErrMsg]: "
		if got := err.Error(); got != want {
			t.Errorf("unexpected error: got %q, want %q", got, want)
		}
	})

	go h.Spin()
	waitEngineRunning(h)

	hc := http.Client{Timeout: time.Second}
	resp, err := hc.Get(fullURL(ln, "/bind?b=1"))
	assert.Nil(t, err)
	resp.Body.Close()
	time.Sleep(100 * time.Millisecond)
}

// TestValidateConfigAndBindConfig checks that a custom ValidateTag ("vt") is
// honoured: a query value outside the `vt` range must make BindAndValidate
// return an error.
func TestValidateConfigAndBindConfig(t *testing.T) {
	type Req struct {
		A int `query:"a" vt:"$>=0&&$<=130"`
	}
	validateConfig := binding.NewValidateConfig()
	validateConfig.ValidateTag = "vt"
	ln := testutils.NewTestListener(t)
	defer ln.Close()
	h := New(
		WithListener(ln),
		WithValidateConfig(validateConfig))

	// The handler runs on a server goroutine, where t.Fatal must not be
	// called; report the bind result back to the test goroutine instead.
	errCh := make(chan error, 1)
	h.GET("/bind", func(c context.Context, ctx *app.RequestContext) {
		var req Req
		errCh <- ctx.BindAndValidate(&req)
	})

	go h.Spin()
	waitEngineRunning(h)

	hc := http.Client{Timeout: time.Second}
	resp, err := hc.Get(fullURL(ln, "/bind?a=135"))
	assert.Nil(t, err)
	resp.Body.Close()

	select {
	case err = <-errCh:
	case <-time.After(time.Second):
		t.Fatal("handler was not invoked")
	}
	if err == nil {
		t.Fatal("expect an error")
	}
	t.Log(err)
}

// TestWithDisableDefaultDate checks that WithDisableDefaultDate(true) stops the
// server from adding a Date header to responses.
func TestWithDisableDefaultDate(t *testing.T) {
	ln := testutils.NewTestListener(t)
	defer ln.Close()
	h := New(
		WithListener(ln),
		WithDisableDefaultDate(true),
	)
	h.GET("/", func(_ context.Context, c *app.RequestContext) {})
	go h.Spin()
	waitEngineRunning(h)

	hc := http.Client{Timeout: time.Second}
	// Check the error explicitly: a failed request would otherwise surface as
	// a nil-pointer panic on r.Header rather than a readable test failure.
	r, err := hc.Get(fullURL(ln, ""))
	assert.Nil(t, err)
	defer r.Body.Close()
	assert.DeepEqual(t, "", r.Header.Get("Date"))
}

// TestWithDisableDefaultContentType checks that WithDisableDefaultContentType(true)
// stops the server from adding a Content-Type header to an empty response.
func TestWithDisableDefaultContentType(t *testing.T) {
	ln := testutils.NewTestListener(t)
	defer ln.Close()
	h := New(
		WithListener(ln),
		WithDisableDefaultContentType(true),
	)
	h.GET("/", func(_ context.Context, c *app.RequestContext) {})
	go h.Spin()
	waitEngineRunning(h)

	hc := http.Client{Timeout: time.Second}
	// Check the error explicitly: a failed request would otherwise surface as
	// a nil-pointer panic on r.Header rather than a readable test failure.
	r, err := hc.Get(fullURL(ln, ""))
	assert.Nil(t, err)
	defer r.Body.Close()
	assert.DeepEqual(t, "", r.Header.Get("Content-Type"))
}

// TestWithSenseClientDisconnection checks that, with WithSenseClientDisconnection
// enabled, a handler's request context stays live while the client is
// connected and is canceled with context.Canceled once the client closes its
// connection.
func TestWithSenseClientDisconnection(t *testing.T) {
	ln := testutils.NewTestListener(t)
	defer ln.Close()
	var closeFlag int32
	// The handler runs on a server goroutine, where t.Fatal (used by the
	// assert helpers) must not be called. Report what it observes over
	// channels and assert on the test goroutine instead.
	hostCh := make(chan string, 1)
	ctxErrCh := make(chan error, 1)
	h := New(WithListener(ln), WithSenseClientDisconnection(true))
	h.GET("/ping", func(c context.Context, ctx *app.RequestContext) {
		hostCh <- string(ctx.Host())
		// Block until the request context is canceled by the server.
		<-c.Done()
		atomic.StoreInt32(&closeFlag, 1)
		ctxErrCh <- c.Err()
	})
	go h.Spin()
	waitEngineRunning(h)

	con, err := net.Dial("tcp", ln.Addr().String())
	assert.Nil(t, err)
	_, err = con.Write([]byte("GET /ping HTTP/1.1\r\nHost: aa\r\n\r\n"))
	assert.Nil(t, err)

	select {
	case host := <-hostCh:
		assert.DeepEqual(t, "aa", host)
	case <-time.After(time.Second):
		t.Fatal("handler was not invoked")
	}
	// Give a premature cancellation a chance to show up before disconnecting.
	time.Sleep(20 * time.Millisecond)
	assert.DeepEqual(t, atomic.LoadInt32(&closeFlag), int32(0))

	assert.Nil(t, con.Close())
	select {
	case ctxErr := <-ctxErrCh:
		assert.DeepEqual(t, context.Canceled, ctxErr)
	case <-time.After(time.Second):
		t.Fatal("request context was not canceled after the client disconnected")
	}
	assert.DeepEqual(t, atomic.LoadInt32(&closeFlag), int32(1))
}

// TestWithSenseClientDisconnectionAndWithOnConnect checks the same
// disconnection-sensing behaviour as TestWithSenseClientDisconnection, but
// with a pass-through OnConnect hook installed: the request context must stay
// live while the client is connected and be canceled with context.Canceled
// once the client closes its connection.
func TestWithSenseClientDisconnectionAndWithOnConnect(t *testing.T) {
	ln := testutils.NewTestListener(t)
	defer ln.Close()
	var closeFlag int32
	// The handler runs on a server goroutine, where t.Fatal (used by the
	// assert helpers) must not be called. Report what it observes over
	// channels and assert on the test goroutine instead.
	hostCh := make(chan string, 1)
	ctxErrCh := make(chan error, 1)
	h := New(WithListener(ln), WithSenseClientDisconnection(true), WithOnConnect(func(ctx context.Context, conn network.Conn) context.Context {
		return ctx
	}))
	h.GET("/ping", func(c context.Context, ctx *app.RequestContext) {
		hostCh <- string(ctx.Host())
		// Block until the request context is canceled by the server.
		<-c.Done()
		atomic.StoreInt32(&closeFlag, 1)
		ctxErrCh <- c.Err()
	})
	go h.Spin()
	waitEngineRunning(h)

	con, err := net.Dial("tcp", ln.Addr().String())
	assert.Nil(t, err)
	_, err = con.Write([]byte("GET /ping HTTP/1.1\r\nHost: aa\r\n\r\n"))
	assert.Nil(t, err)

	select {
	case host := <-hostCh:
		assert.DeepEqual(t, "aa", host)
	case <-time.After(time.Second):
		t.Fatal("handler was not invoked")
	}
	// Give a premature cancellation a chance to show up before disconnecting.
	time.Sleep(20 * time.Millisecond)
	assert.DeepEqual(t, atomic.LoadInt32(&closeFlag), int32(0))

	assert.Nil(t, con.Close())
	select {
	case ctxErr := <-ctxErrCh:
		assert.DeepEqual(t, context.Canceled, ctxErr)
	case <-time.After(time.Second):
		t.Fatal("request context was not canceled after the client disconnected")
	}
	assert.DeepEqual(t, atomic.LoadInt32(&closeFlag), int32(1))
}

// TestServerReturns413And431OnSizeLimits checks that a request whose header
// exceeds WithMaxHeaderBytes is rejected with 431 (Request Header Fields Too
// Large) and one whose body exceeds WithMaxRequestBodySize is rejected with
// 413 (Request Entity Too Large).
func TestServerReturns413And431OnSizeLimits(t *testing.T) {
	ln := testutils.NewTestListener(t)
	defer ln.Close()
	h := Default(WithListener(ln), WithMaxHeaderBytes(500), WithMaxRequestBodySize(1000))

	h.GET("/test", func(c context.Context, ctx *app.RequestContext) {
		ctx.String(consts.StatusOK, "success")
	})
	h.POST("/test", func(c context.Context, ctx *app.RequestContext) {
		ctx.String(consts.StatusOK, "success")
	})

	go h.Spin()
	waitEngineRunning(h)
	defer h.Shutdown(context.Background())

	target := fmt.Sprintf("http://%s/test", ln.Addr().String())
	client := &http.Client{Timeout: 2 * time.Second}

	// 431: a single header value larger than the 500-byte header limit.
	req, err := http.NewRequest(http.MethodGet, target, nil)
	assert.Nil(t, err)
	req.Header.Set("Large-Header", strings.Repeat("x", 501))

	resp, err := client.Do(req)
	assert.Nil(t, err)
	resp.Body.Close()
	assert.DeepEqual(t, consts.StatusRequestHeaderFieldsTooLarge, resp.StatusCode)

	// 413: a body one byte larger than the 1000-byte body limit.
	req, err = http.NewRequest(http.MethodPost, target, strings.NewReader(strings.Repeat("x", 1001)))
	assert.Nil(t, err)

	resp, err = client.Do(req)
	assert.Nil(t, err)
	resp.Body.Close()
	assert.DeepEqual(t, consts.StatusRequestEntityTooLarge, resp.StatusCode)
}
